Shrinkage

Suppose we observe independent draws \[X_1, \ldots, X_n \sim \text{Normal}\left(\begin{bmatrix} 2 \\ 3 \\ 5 \end{bmatrix}, \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1\end{bmatrix}\right)\]

For simplicity, suppose the covariance matrix \(\Sigma = \mathbb{I}\) is known but the mean vector \(\mu = (2, 3, 5)'\) is unknown; we want to estimate this parameter.
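A minimal R sketch of this setup draws a sample from the distribution above and computes the sample mean \(\bar{x}\), the unshrunk estimate that the estimators below adjust. The sample size `n = 50` and the seed are arbitrary choices for illustration, not values from the original.

```r
# Simulate n draws from the three-dimensional normal in the example.
# n = 50 and the seed are arbitrary illustrative choices.
set.seed(1)
n  <- 50
mu <- c(2, 3, 5)   # true mean (unknown in practice)

# With identity covariance the coordinates are independent N(mu_j, 1).
# matrix() fills column by column, so column j gets mean mu[j].
x <- matrix(rnorm(n * 3, mean = rep(mu, each = n), sd = 1), nrow = n, ncol = 3)

# Sample mean: the maximum likelihood estimate of mu when Sigma is known
x_bar <- colMeans(x)
x_bar
```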

Here are three estimators we might choose from:

Bayesian, Normal Prior

If we use a prior of

\[\mu \sim \text{Normal}\left(\gamma_0, \Lambda_0\right),\]

then the posterior distribution for \(\mu\) is

\[\mu | x_1, \ldots, x_n, \gamma_0, \Lambda_0 \sim \text{Normal}\left((\Lambda_0^{-1} + n \Sigma^{-1})^{-1}(\Lambda_0^{-1} \gamma_0 + n \Sigma^{-1} \bar{x}), (\Lambda_0^{-1} + n \Sigma^{-1})^{-1} \right)\]
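The posterior formula translates directly into R. This is a minimal sketch: the function name `normal_posterior` is hypothetical, and the example inputs (prior mean 0, prior covariance \(0.1\mathbb{I}\), \(\Sigma = \mathbb{I}\), sample mean \((2.1, 2.8, 4.7)'\)) follow the illustration in the next paragraph, with \(n = 10\) chosen arbitrarily.

```r
# Posterior mean and covariance of mu under a Normal(gamma0, Lambda0) prior,
# given n observations with known covariance Sigma and sample mean x_bar.
normal_posterior <- function(x_bar, n, Sigma, gamma0, Lambda0) {
  prior_prec <- solve(Lambda0)        # Lambda0^{-1}
  data_prec  <- n * solve(Sigma)      # n Sigma^{-1}
  post_cov   <- solve(prior_prec + data_prec)
  post_mean  <- post_cov %*% (prior_prec %*% gamma0 + data_prec %*% x_bar)
  list(mean = drop(post_mean), cov = post_cov)
}

post <- normal_posterior(
  x_bar   = c(2.1, 2.8, 4.7),
  n       = 10,
  Sigma   = diag(3),
  gamma0  = c(0, 0, 0),
  Lambda0 = 0.1 * diag(3)
)
post$mean   # approximately 1.05 1.40 2.35, i.e. 0.5 * x_bar
post$cov    # 0.05 on the diagonal
```

With \(n = 10\) the prior precision \(10\mathbb{I}\) and the data precision \(10\mathbb{I}\) are equal, so the posterior mean sits exactly halfway between the prior mean and the sample mean.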

What is this posterior doing?

Suppose we use a prior mean of \(\gamma_0 = (0, 0, 0)'\) and a prior covariance of \(\Lambda_0 = 0.1 \mathbb{I}\). As the sample size \(n\) increases, the data increasingly dominate the prior and the posterior mean moves towards the sample mean. For illustration, let’s hold the sample mean fixed at \((2.1, 2.8, 4.7)'\) while increasing the sample size; that is, imagine taking ever larger samples that happen to always produce the same sample mean.
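With \(\Sigma = \mathbb{I}\), \(\gamma_0 = (0, 0, 0)'\) and \(\Lambda_0 = 0.1\mathbb{I}\), the posterior mean simplifies to a scalar multiple of the sample mean, which makes the movement toward \(\bar{x}\) explicit:

\[
\begin{aligned}
\Lambda_0^{-1} + n\Sigma^{-1} &= 10\,\mathbb{I} + n\,\mathbb{I} = (n + 10)\,\mathbb{I}, \\
E[\mu \mid x_1, \ldots, x_n] &= \frac{1}{n + 10}\left(10 \cdot \mathbf{0} + n\,\bar{x}\right) = \frac{n}{n + 10}\,\bar{x}.
\end{aligned}
\]

For example, with \(n = 10\) the posterior mean is \(0.5\,\bar{x} = (1.05, 1.4, 2.35)'\), and with \(n = 90\) it is \(0.9\,\bar{x}\); the factor \(n/(n + 10)\) approaches 1 as \(n\) grows.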

In the plot below, the blue point is at the mean of the prior distribution for \(\mu\), the red point is at the sample mean, and the black point is at the mean of the posterior distribution for \(\mu\). The slider controls the sample size \(n\).
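The path of the black point can also be traced numerically. This sketch assumes \(\Sigma = \mathbb{I}\) as in the example and uses an arbitrary grid of sample sizes `n_grid`; it relies on the fact that, when both covariances are multiples of the identity, the posterior mean is a precision-weighted average of \(\gamma_0\) and \(\bar{x}\).

```r
# Posterior mean under the prior gamma0 = 0, Lambda0 = 0.1 I, with Sigma = I,
# holding the sample mean fixed while the sample size n grows.
x_bar      <- c(2.1, 2.8, 4.7)
gamma0     <- c(0, 0, 0)
prior_prec <- 1 / 0.1                    # Lambda0^{-1} = 10 I
n_grid     <- c(1, 5, 10, 50, 100, 1000) # illustrative sample sizes

# Isotropic covariances reduce the posterior mean to a weighted average
post_means <- t(sapply(n_grid, function(n) {
  (prior_prec * gamma0 + n * x_bar) / (prior_prec + n)
}))
rownames(post_means) <- paste0("n = ", n_grid)

# Distance from the posterior mean (black point) to the sample mean (red point)
dist_to_x_bar <- sqrt(rowSums(sweep(post_means, 2, x_bar)^2))
round(dist_to_x_bar, 3)
```

These distances equal \(10\lVert\bar{x}\rVert/(n + 10)\), so the black point approaches the red point but never reaches it for finite \(n\).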

James-Stein

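The James-Stein estimator, for \(p \geq 3\) dimensions and known covariance \(\Sigma = \sigma^2 \mathbb{I}\), is

\(\hat{\mu} = \left(1 - \frac{(p - 2)\sigma^2}{n\sum_{j = 1}^p\bar{x}_j^2} \right)\bar{x}\)

In this example \(p = 3\) and \(\sigma^2 = 1\), so the shrinkage factor is \(1 - 1/\left(n\sum_{j = 1}^3\bar{x}_j^2\right)\), which pulls \(\bar{x}\) toward the origin.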
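A direct R translation of the formula, as a sketch: `james_stein` is a hypothetical helper name, and `sigma2` defaults to 1 to match the identity covariance in this example. Unlike the Bayesian estimator, it needs no prior; the amount of shrinkage is estimated from the data through \(\sum_j \bar{x}_j^2\).

```r
# James-Stein estimate of a p-dimensional mean (p >= 3) from the sample mean
# of n observations with covariance sigma2 * I. Shrinks toward the origin.
james_stein <- function(x_bar, n, sigma2 = 1) {
  p <- length(x_bar)
  stopifnot(p >= 3)
  shrink <- 1 - (p - 2) * sigma2 / (n * sum(x_bar^2))
  shrink * x_bar
}

x_bar <- c(2.1, 2.8, 4.7)    # sum(x_bar^2) = 34.34
james_stein(x_bar, n = 1)    # factor 1 - 1/34.34, about 0.971
james_stein(x_bar, n = 100)  # factor about 0.9997: almost no shrinkage
```

Because the factor is computed from the data, it is close to 1 when \(n\sum_j \bar{x}_j^2\) is large. It becomes negative when \(n\sum_j \bar{x}_j^2 < (p - 2)\sigma^2\), which is why a positive-part variant that truncates the factor at 0 is often used in practice.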

Effect of Shrinkage

Bayes, Prior Mean 0

The proportion of the surface of the sphere that moves closer to the center after shrinkage is approximately…

mean(dist_diff_to_plot < 0)
## [1] 0.5782878
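The setup behind this proportion (the sphere's center and radius, the sample size, the prior covariance) is not shown, so the following Monte Carlo sketch uses placeholder values: a sphere of radius 1 centered at the true mean \(\mu\), \(n = 1\), \(\Sigma = \mathbb{I}\), and a weak prior covariance \(\Lambda_0 = 100\mathbb{I}\). The helpers `bayes_mean` and `prop_closer` are hypothetical names. Each point on the sphere is treated as a possible sample mean, shrunk toward \(\gamma_0 = 0\), and compared with its original distance to the center; the result depends on the assumed values and is not expected to reproduce the number above exactly.

```r
set.seed(1)
mu         <- c(2, 3, 5)   # assumed center of the sphere (the true mean)
radius     <- 1            # hypothetical sphere radius
n          <- 1            # hypothetical sample size
prior_prec <- 0.01         # hypothetical prior covariance Lambda0 = 100 I
n_pts      <- 20000

# Uniform points on the sphere: normalize standard normal draws, then
# scale by the radius and shift to the center mu
z      <- matrix(rnorm(n_pts * 3), ncol = 3)
z_unit <- z / sqrt(rowSums(z^2))
x_bars <- sweep(radius * z_unit, 2, mu, "+")

# Posterior mean with Sigma = I and Lambda0 = I / prior_prec
bayes_mean <- function(x_bar, gamma0) {
  (prior_prec * gamma0 + n * x_bar) / (prior_prec + n)
}

# Fraction of points whose distance to mu decreases after shrinkage
prop_closer <- function(x_bars, mu, shrink_fn) {
  shrunk   <- t(apply(x_bars, 1, shrink_fn))
  d_before <- sqrt(rowSums(sweep(x_bars, 2, mu)^2))
  d_after  <- sqrt(rowSums(sweep(shrunk, 2, mu)^2))
  mean(d_after - d_before < 0)
}

prop_closer(x_bars, mu, function(xb) bayes_mean(xb, gamma0 = c(0, 0, 0)))
```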

Bayes, Prior Mean (2, 3, 5)

The proportion of the surface of the sphere that moves closer to the center after shrinkage is approximately…

mean(dist_diff_to_plot < 0)
## [1] 1
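A short argument explains why the proportion is exactly 1 here. Assume, as "moves closer to the center" suggests, that the sphere is centered at the true mean \(\mu\), and that the prior covariance is isotropic, \(\Lambda_0 = \lambda\mathbb{I}\) with \(\Sigma = \mathbb{I}\) (both are assumptions, since the setup of this figure is not shown). Then the posterior mean pulls \(\bar{x}\) straight toward \(\gamma_0\):

\[
\begin{aligned}
E[\mu \mid x_1, \ldots, x_n] &= \left(\tfrac{1}{\lambda} + n\right)^{-1}\left(\tfrac{1}{\lambda}\gamma_0 + n\bar{x}\right) = \gamma_0 + c\,(\bar{x} - \gamma_0), \qquad c = \frac{n}{n + 1/\lambda} \in (0, 1), \\
\left\lVert E[\mu \mid x_1, \ldots, x_n] - \mu \right\rVert &= c\,\lVert \bar{x} - \mu \rVert < \lVert \bar{x} - \mu \rVert \quad \text{when } \gamma_0 = \mu.
\end{aligned}
\]

With \(\gamma_0 = \mu = (2, 3, 5)'\), every point on the sphere therefore ends up strictly closer to the center after shrinkage.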

James-Stein

The proportion of the surface of the sphere that moves closer to the center after shrinkage is approximately…

mean(dist_diff_to_plot < 0)
## [1] 0.5403646
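A Monte Carlo sketch of the same comparison for James-Stein, under stated assumptions: a sphere of radius 1 centered at the true mean \(\mu\), \(n = 1\), and \(\sigma^2 = 1\) (the actual settings are not shown). Under these assumptions a short calculation gives an exact proportion of \((1 + 1/(2\lVert\mu\rVert))/2 \approx 0.541\), using the fact that in three dimensions the cosine of the angle between a uniformly random direction and \(\mu\) is uniform on \([-1, 1]\); this is close to the value reported above, so the simulation should land near 0.54.

```r
set.seed(1)
mu     <- c(2, 3, 5)   # assumed center of the sphere (the true mean)
radius <- 1            # hypothetical sphere radius
n      <- 1            # hypothetical sample size
n_pts  <- 20000

# Uniform points on the sphere around mu, each treated as a sample mean
z      <- matrix(rnorm(n_pts * 3), ncol = 3)
z_unit <- z / sqrt(rowSums(z^2))
x_bars <- sweep(radius * z_unit, 2, mu, "+")

# James-Stein with sigma^2 = 1: shrink toward the origin
james_stein <- function(x_bar) {
  (1 - (length(x_bar) - 2) / (n * sum(x_bar^2))) * x_bar
}

shrunk   <- t(apply(x_bars, 1, james_stein))
d_before <- sqrt(rowSums(sweep(x_bars, 2, mu)^2))
d_after  <- sqrt(rowSums(sweep(shrunk, 2, mu)^2))
p_js     <- mean(d_after - d_before < 0)   # fraction moved closer to mu
p_js
```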